import pandas as pd
import time
import json
import random
from tqdm import tqdm
import gin
from perturbation import perturbate_trajectory
import argparse

@gin.configurable
def process_row(row, news_df, summ_df, purturbed_summ_df, curr_summ_id, previous_doc_steps, decay_rate, perturbation_probability):
    """Perturb a single row via ``perturbate_trajectory``, best-effort.

    All arguments are forwarded unchanged, in order, to
    ``perturbate_trajectory``.

    Args:
        row: The row to process. Must support ``.get("Docs", default)``;
            the caller in this file passes a row from ``DataFrame.iterrows()``.
        news_df: News DataFrame handed through to ``perturbate_trajectory``.
        summ_df: Summary DataFrame handed through to ``perturbate_trajectory``.
        purturbed_summ_df: Perturbed-summary DataFrame handed through and
            replaced by the value ``perturbate_trajectory`` returns.
        curr_summ_id: Current summary id counter, handed through and replaced
            by the value ``perturbate_trajectory`` returns.
        previous_doc_steps: Forwarded perturbation setting.
        decay_rate: Forwarded perturbation setting.
        perturbation_probability: Forwarded perturbation setting.

    Returns:
        A tuple ``(docs, purturbed_summ_df, curr_summ_id)``. On success these
        are the values produced by ``perturbate_trajectory``. If it raises,
        the error is logged and the tuple holds the row's existing ``"Docs"``
        value (or ``"No Docs Found"`` when absent) together with the
        *unchanged* ``purturbed_summ_df`` and ``curr_summ_id`` passed in.
    """
    # Local import keeps this block self-contained; the module does not
    # import logging at file level.
    import logging

    try:
        purturbed_summ_df, curr_summ_id, updated_doc_list = perturbate_trajectory(
            row, news_df, summ_df, purturbed_summ_df, curr_summ_id,
            previous_doc_steps, decay_rate, perturbation_probability
        )
    except Exception:
        # Best-effort by design: one bad row must not abort the whole run.
        # Record the failure (with traceback) instead of hiding it, so that
        # rows left unperturbed can be identified afterwards.
        # `row.name` is the index label when row is a pandas Series; fall
        # back to None for other row-like objects.
        logging.getLogger(__name__).exception(
            "perturbate_trajectory failed for row %r; keeping its original Docs",
            getattr(row, "name", None),
        )
        return row.get("Docs", "No Docs Found"), purturbed_summ_df, curr_summ_id

    return updated_doc_list, purturbed_summ_df, curr_summ_id

@gin.configurable
def main(args, decay_rate, previous_doc_steps, perturbation_probability, output_file_path, aug_output_file_path, final_output_file_path, checkpoint_every=10000):
    """Run the perturbation pass over the augmented dataset and save results.

    Loads the three input tables named in ``args``, calls ``process_row`` for
    each row of the augmented data, stores the returned docs in the ``Docs``
    column of a copy of that data, and writes checkpoints plus final outputs.

    Args:
        args: Parsed command-line namespace. Reads ``aug_file`` (CSV),
            ``news_file`` (tab-separated), ``summ_file`` (CSV, ISO-8859-1)
            and ``limit`` (a falsy value, including 0, means "no limit").
        decay_rate: Forwarded to ``process_row``.
        previous_doc_steps: Forwarded to ``process_row``.
        perturbation_probability: Forwarded to ``process_row``.
        output_file_path: CSV path for the perturbed summaries.
        aug_output_file_path: CSV path for the in-progress checkpoint of the
            updated augmented data (rows processed so far).
        final_output_file_path: CSV path for the complete updated augmented
            data, written once after the loop.
        checkpoint_every: Number of processed rows between checkpoints.
            A falsy value disables periodic checkpoints; the final outputs
            are still written.

    All parameters other than ``args`` are expected to be bound through gin,
    since the script entry point calls ``main(args)`` only.
    """
    print(f'decay_rate: {decay_rate}\nprevious_doc_steps:{previous_doc_steps}\nperturbation_probability:{perturbation_probability}\n')

    # Load data
    aug_df = pd.read_csv(args.aug_file)
    news_df = pd.read_csv(args.news_file, sep="\t")
    summ_df = pd.read_csv(args.summ_file, encoding='ISO-8859-1')

    # Index the lookup tables by their id columns.
    news_df.set_index('News ID', inplace=True)
    summ_df.set_index('SummID', inplace=True)

    purturbed_summ_df = pd.DataFrame(columns=['SummID', 'NewsID', 'UserID', 'Summary'])

    aug_df = aug_df[:args.limit] if args.limit else aug_df

    curr_summ_id = 1
    start_time = time.time()

    # Updates go into an independent copy, so the frame being iterated is
    # never modified during the loop.
    updated_aug_df = aug_df.copy()

    progress = tqdm(aug_df.iterrows(), total=len(aug_df), desc="Processing rows", unit="row")
    for idx, (index, row) in enumerate(progress):
        updated_docs, purturbed_summ_df, curr_summ_id = process_row(
            row, news_df, summ_df, purturbed_summ_df, curr_summ_id,
            previous_doc_steps, decay_rate, perturbation_probability
        )

        # Update the DataFrame
        updated_aug_df.at[index, 'Docs'] = updated_docs

        # Periodic checkpoint of both outputs.
        if checkpoint_every and (idx + 1) % checkpoint_every == 0:
            # Rewrite the whole accumulated frame rather than diffing against
            # the file on disk: DataFrame.isin(DataFrame) compares cell-wise
            # on aligned labels and can append partially-NaN rows, and a
            # stale file from an earlier run would otherwise be appended to.
            # NOTE(review): assumes purturbed_summ_df returned by process_row
            # is the full accumulated table - confirm in perturbate_trajectory.
            purturbed_summ_df.to_csv(output_file_path, index=False, mode='w', header=True)
            updated_aug_df.iloc[:idx + 1].to_csv(aug_output_file_path, index=False, mode='w')

    # Always persist the perturbed summaries at the end; otherwise runs
    # shorter than one checkpoint interval would never write them, and rows
    # after the last checkpoint would be lost.
    purturbed_summ_df.to_csv(output_file_path, index=False, mode='w', header=True)
    updated_aug_df.to_csv(final_output_file_path, index=False)

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Final updated DataFrame saved to {final_output_file_path}")
    print(f"Total time for processing: {total_time:.2f} seconds")

if __name__ == "__main__":
    # Script entry point: parse the command line, load the gin bindings,
    # then hand the parsed namespace to main().
    cli_parser = argparse.ArgumentParser(
        description="Process and perturb trajectories in augmented data."
    )

    # The four path options share the same shape (mandatory string), so they
    # are registered from a single table, in the original order.
    required_paths = (
        ('--gin_file', "Path to the gin configuration file."),
        ('--aug_file', "Path to the augmented data file."),
        ('--news_file', "Path to the news data file."),
        ('--summ_file', "Path to the summary data file."),
    )
    for flag, description in required_paths:
        cli_parser.add_argument(flag, type=str, required=True, help=description)

    # Optional cap on how many rows are processed.
    cli_parser.add_argument(
        '--limit',
        type=int,
        default=None,
        help="Limit the number of rows to process.",
    )

    args = cli_parser.parse_args()

    # The gin file has to be parsed first: main() is called with `args` only,
    # so its remaining parameters must come from gin bindings.
    gin.parse_config_file(args.gin_file)
    main(args)


"""
Sample run
python driver.py \
    --gin_file='config.gin' \
    --aug_file='data/synthetic_dataset.csv' \
    --news_file='data/news_dataset.tsv' \
    --summ_file='data/summary_dataset.csv' \
    ;
"""